Libraries

library(mice)
## Warning: package 'mice' was built under R version 4.1.3
library(lattice)
library(tidyverse)

library(rMIDAS)
# Python environment used by rMIDAS. The default is the Windows Anaconda
# install this report was run with; set the RMIDAS_PYTHON_PATH environment
# variable to point at another conda installation instead of editing the file.
# Alternative configuration used previously:
# set_python_env(python ="/opt/anaconda3/bin/python")
set_python_env(
  x = Sys.getenv("RMIDAS_PYTHON_PATH", unset = "C:\\ProgramData\\Anaconda3\\"),
  type = "conda"
)
## [1] TRUE
library(ggplot2)
library(gridExtra)
library(GGally)

library(gdata)

Data

# Load the Adult data extract from the MIDASpy examples repository. The first
# CSV column supplies the row names. Only the first 3000 rows are kept.
# (A colClasses = c("NULL", NA, NA, NA) variant was tried and left disabled.)
data <- read.csv(
  "https://raw.githubusercontent.com/MIDASverse/MIDASpy/master/Examples/adult_data.csv",
  row.names = 1
)
data <- data[seq_len(3000), ]
head(data)
##   age        workclass fnlwgt education education_num     marital_status
## 0  39        State-gov  77516 Bachelors            13      Never-married
## 1  50 Self-emp-not-inc  83311 Bachelors            13 Married-civ-spouse
## 2  38          Private 215646   HS-grad             9           Divorced
## 3  53          Private 234721      11th             7 Married-civ-spouse
## 4  28          Private 338409 Bachelors            13 Married-civ-spouse
## 5  37          Private 284582   Masters            14 Married-civ-spouse
##          occupation  relationship  race    sex capital_gain capital_loss
## 0      Adm-clerical Not-in-family White   Male         2174            0
## 1   Exec-managerial       Husband White   Male            0            0
## 2 Handlers-cleaners Not-in-family White   Male            0            0
## 3 Handlers-cleaners       Husband Black   Male            0            0
## 4    Prof-specialty          Wife Black Female            0            0
## 5   Exec-managerial          Wife White Female            0            0
##   hours_per_week native_country class_labels
## 0             40  United-States        <=50K
## 1             13  United-States        <=50K
## 2             40  United-States        <=50K
## 3             40  United-States        <=50K
## 4             40           Cuba        <=50K
## 5             40  United-States        <=50K

Data exploration

# Column groups: multi-level categorical, binary, and numeric columns.
adult_cat <- c("workclass", "marital_status", "relationship", "race",
               "education", "occupation", "native_country")
adult_bin <- c("sex", "class_labels")
adult_num <- c("age", "fnlwgt", "education_num", "capital_gain",
               "capital_loss", "hours_per_week")

# Store every binary and categorical column as a factor.
data[c(adult_bin, adult_cat)] <- lapply(data[c(adult_bin, adult_cat)], as.factor)


# qplot(data$workclass)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
# qplot(data$marital_status)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
# qplot(data$relationship)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
# qplot(data$race)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
# qplot(data$education)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
# qplot(data$occupation)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))
# qplot(data$native_country)+theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust=1))

Imputation with MICE

#### Create Missing Data
# Inject missing values with rMIDAS::add_missingness (prop = 0.2) and coerce
# the result back to a plain data.frame.
miss_data <- as.data.frame(add_missingness(data, prop = 0.2))
# Report the number of missing values in each column.
print(vapply(miss_data, function(x) sum(is.na(x)), integer(1)))
##            age      workclass         fnlwgt      education  education_num 
##            626            592            571            627            593 
## marital_status     occupation   relationship           race            sex 
##            639            580            601            606            611 
##   capital_gain   capital_loss hours_per_week native_country   class_labels 
##            597            578            595            578            591
#### Imputing Data with missRanger
library(missRanger)

# Ten independent missRanger runs act as ten imputed datasets, matching the
# m = 10 used for the other methods in this report.
impt_ranger_data <- replicate(
  10,
  as.data.frame(missRanger(miss_data, verbose = 0, num.trees = 100)),
  simplify = FALSE
)

# Round the binary/categorical columns to whole level codes. This only applies
# when they come back numeric (presumably when add_missingness() returns
# numeric codes -- TODO confirm). round() errors on factors, so factor columns
# are left untouched.
impt_ranger_data <- lapply(impt_ranger_data, function(imputed) {
  for (col_name in c(adult_bin, adult_cat)) {
    if (is.numeric(imputed[[col_name]])) {
      imputed[[col_name]] <- round(imputed[[col_name]])
    }
  }
  imputed
})

Imputing Data with MICE

# Dry run (maxit = 0): builds mice's default per-column method vector without
# running any imputation iterations.
imp <- mice(miss_data, maxit = 0, print = FALSE)
meth <- imp$method
# Override the defaults: CART for multi-level categoricals, random forests for
# the binary and numeric columns.
meth[adult_cat] <- "cart"
meth[adult_bin] <- "rf"
meth[adult_num] <- "rf"

# 10 imputations with mice's default number of iterations, then continue the
# same chains for 15 more iterations.
imp <- mice(miss_data, m = 10, method = meth, print = FALSE)
imp20 <- mice.mids(imp, maxit = 15, print = FALSE)

# Extract the 10 completed datasets.
impt_mice_data <- lapply(seq_len(10), function(i) mice::complete(imp20, action = i))

Imputing Data with rMIDAS

# Encode miss_data for rMIDAS. The binary and categorical columns are declared
# explicitly, and minmax_scale = TRUE is requested.
data_conv <- rMIDAS::convert(
  miss_data,
  bin_cols = adult_bin,
  cat_cols = adult_cat,
  minmax_scale = TRUE
)

# Train the model for 20 epochs (layer_structure c(128, 128),
# input dropout 0.75, fixed seed 89).
rmidas_train <- rMIDAS::train(
  data_conv,
  training_epochs = 20,
  layer_structure = c(128, 128),
  input_drop = 0.75,
  seed = 89
)

# Draw 10 completed datasets from the trained model.
impt_rmidas_data <- rMIDAS::complete(rmidas_train, m = 10, fast = TRUE)

Plot results

# Build a long data frame for plotting imputed values against the truth.
#
# Only rows whose `col` value is missing in `miss_df` are kept. The result
# stacks imputed datasets m, m - 1, ..., 1 and then the true rows, so the true
# rows come last (and are plotted last).
#
# Arguments:
#   df           complete (true) data frame.
#   miss_df      `df` with missing values; coerced to a data.frame.
#   impt_df_list list of imputed datasets with the same rows/columns as `df`.
#   col          name of the column whose missing cells select the rows.
#   m            number of datasets from `impt_df_list` to include.
#   method       label written to `source` for imputed rows.
#   sp_impt      "method": imputed rows get source "<method>-<i>";
#                "sex":    every source becomes "<sex>-<source>";
#                other:    source stays `method` / "True".
# Returns a data.frame with the columns of `df` plus `source` and `na_count`.
# `na_count` is a character column holding the row's NA count in `miss_df`,
# or "True(0 na)" for the true rows.
create_compare_data <- function(df, miss_df, impt_df_list, col, m = 10,
                                method = "mice", sp_impt = "sex") {
  miss_df <- as.data.frame(miss_df)
  # Only the cells that were actually imputed are compared.
  miss_index <- which(is.na(miss_df[[col]]))
  n_miss <- length(miss_index)
  na_count <- as.character(
    rowSums(is.na(miss_df[miss_index, , drop = FALSE]))
  )

  truth <- df[miss_index, , drop = FALSE]
  truth$source <- rep("True", n_miss)
  truth$na_count <- rep("True(0 na)", n_miss)

  imputed <- lapply(seq_len(m), function(i) {
    imp <- as.data.frame(impt_df_list[[i]])[miss_index, , drop = FALSE]
    label <- if (sp_impt == "method") paste(method, i, sep = "-") else method
    imp$source <- rep(label, n_miss)
    imp$na_count <- na_count
    imp
  })

  # Single rbind instead of growing inside a loop; rows are ordered
  # imputation m, ..., 1, then the true rows.
  out <- do.call(rbind, c(rev(imputed), list(truth)))

  if (sp_impt == "sex") {
    out$source <- paste(out$sex, out$source, sep = "-")
  }
  out
}




library(scales)
library(caret)
library(gdata)

ggplotConfusionMatrix <- function(m, col_names, method_name) {
  # Adapted from
  # https://stackoverflow.com/questions/51410405/ggplot2-confusion-matrix-geom-text-labeling
  #
  # Draw a caret confusion matrix as a heat map.
  #   m           object returned by caret::confusionMatrix().
  #   col_names   axis labels for the class levels, in level order.
  #   method_name prefix for the plot title.
  # Returns the ggplot object. Each tile is labelled with its share of the
  # corresponding Reference column.
  accuracy <- m$overall[["Accuracy"]]
  kappa <- m$overall[["Kappa"]]
  # Kappa is an agreement statistic (it can be negative), not a proportion,
  # so it is shown as a plain number rather than a percentage.
  mytitle <- paste(method_name, "Accuracy", percent(accuracy),
                   "Kappa", sprintf("%.3f", kappa))

  data_c <- group_by(as.data.frame(m$table), Reference)
  data_c <- mutate(data_c, share = Freq / sum(Freq))
  data_c <- ungroup(data_c)
  # A Reference column with no observations gives 0/0. Label those tiles ""
  # instead of NA so geom_text() does not drop them with a warning.
  data_c$percentage <- ifelse(is.finite(data_c$share), percent(data_c$share), "")

  ggplot(data = data_c, aes(x = Reference, y = Prediction)) +
    geom_tile(aes(fill = Freq), colour = "white") +
    scale_fill_gradient(low = "white", high = "green") +
    geom_text(aes(label = percentage)) +
    scale_x_discrete(labels = col_names, guide = guide_axis(angle = 45)) +
    scale_y_discrete(labels = col_names) +
    ggtitle(mytitle)
}



# Convert imputed or true values of one categorical column to level labels.
#   x   factor, character, or numeric vector.
#   lv  the column's level labels, in level order.
# Factors, and character vectors whose values are all labels in `lv`, are
# returned as character. Numeric values, and character values that all parse
# as numbers, are treated as 1-based level codes: they are rounded and mapped
# onto `lv`, and out-of-range codes become NA.
to_level_labels <- function(x, lv) {
  if (is.factor(x)) {
    x <- as.character(x)
  }
  if (is.character(x)) {
    if (all(x %in% c(lv, NA))) {
      return(x)
    }
    codes <- suppressWarnings(as.numeric(x))
    if (any(is.na(codes) & !is.na(x))) {
      # Unrecognised non-numeric labels: returned unchanged; they are not in
      # `lv`, so they become NA once turned into a factor over `lv` levels.
      return(x)
    }
    x <- codes
  }
  idx <- as.integer(round(x))
  idx[!is.na(idx) & (idx < 1L | idx > length(lv))] <- NA_integer_
  lv[idx]
}

# Plot a confusion matrix of imputed vs. true values for categorical column
# `col`. Only cells that are missing in `miss_df` are used, pooled over the
# first `m` datasets in `impt_data_list`. Values are compared by level label,
# so factor, character, and numeric-code imputations are all handled. Only
# levels that occur among the true or imputed values are shown.
# Returns the ggplot object built by ggplotConfusionMatrix().
plot_confusion_matrix <- function(impt_data_list, data, miss_df, col, method_name, m = 10) {
  miss_df <- as.data.frame(miss_df)
  miss_index <- which(is.na(miss_df[[col]]))
  if (length(miss_index) == 0L) {
    stop("column '", col, "' has no missing values in miss_df", call. = FALSE)
  }
  lv <- levels(as.factor(data[[col]]))

  # Truth is repeated once per imputed dataset so it pairs element-wise with
  # the concatenated predictions.
  true_chr <- rep(to_level_labels(data[[col]][miss_index], lv), times = m)
  pred_chr <- unlist(
    lapply(seq_len(m), function(i) {
      to_level_labels(impt_data_list[[i]][[col]][miss_index], lv)
    }),
    use.names = FALSE
  )

  # Both factors share one level set, so predictions of a class absent from
  # the truth are counted as errors instead of being dropped. The axis labels
  # are taken from the same vector, which keeps them aligned with the tiles.
  shown <- lv[lv %in% c(true_chr, pred_chr)]
  pred_values <- factor(pred_chr, levels = shown)
  true_labels <- factor(true_chr, levels = shown)

  # caret expects predictions as `data` and the truth as `reference`.
  cfm <- confusionMatrix(data = pred_values, reference = true_labels)
  ggplotConfusionMatrix(cfm, shown, method_name)
}

marital_status

# Distribution of marital_status in the complete data (qplot() is deprecated).
ggplot(data, aes(x = marital_status)) +
  geom_bar() +
  theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1))

plot_confusion_matrix(impt_ranger_data, data, miss_data,
                      col = "marital_status", method_name = "ranger")

plot_confusion_matrix(impt_mice_data, data, miss_data,
                      col = "marital_status", method_name = "mice")

plot_confusion_matrix(impt_rmidas_data, data, miss_data,
                      col = "marital_status", method_name = "rmidas")

workclass

# Distribution of workclass in the complete data (qplot() is deprecated).
ggplot(data, aes(x = workclass)) +
  geom_bar() +
  theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1))

plot_confusion_matrix(impt_ranger_data, data, miss_data,
                      col = "workclass", method_name = "ranger")
## Warning: Removed 14 rows containing missing values (geom_text).

plot_confusion_matrix(impt_mice_data, data, miss_data,
                      col = "workclass", method_name = "mice")

plot_confusion_matrix(impt_rmidas_data, data, miss_data,
                      col = "workclass", method_name = "rmidas")

relationship

# Distribution of relationship in the complete data (qplot() is deprecated).
ggplot(data, aes(x = relationship)) +
  geom_bar() +
  theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1))

plot_confusion_matrix(impt_ranger_data, data, miss_data,
                      col = "relationship", method_name = "ranger")
## Warning: Removed 6 rows containing missing values (geom_text).

plot_confusion_matrix(impt_mice_data, data, miss_data,
                      col = "relationship", method_name = "mice")

plot_confusion_matrix(impt_rmidas_data, data, miss_data,
                      col = "relationship", method_name = "rmidas")

## race

# Distribution of race in the complete data (qplot() is deprecated).
ggplot(data, aes(x = race)) +
  geom_bar() +
  theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1))

plot_confusion_matrix(impt_ranger_data, data, miss_data,
                      col = "race", method_name = "ranger")
## Warning: Removed 10 rows containing missing values (geom_text).

plot_confusion_matrix(impt_mice_data, data, miss_data,
                      col = "race", method_name = "mice")

plot_confusion_matrix(impt_rmidas_data, data, miss_data,
                      col = "race", method_name = "rmidas")

## education

# Distribution of education in the complete data (qplot() is deprecated).
ggplot(data, aes(x = education)) +
  geom_bar() +
  theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1))

plot_confusion_matrix(impt_ranger_data, data, miss_data,
                      col = "education", method_name = "ranger")

plot_confusion_matrix(impt_mice_data, data, miss_data,
                      col = "education", method_name = "mice")

plot_confusion_matrix(impt_rmidas_data, data, miss_data,
                      col = "education", method_name = "rmidas")
## Warning: Removed 32 rows containing missing values (geom_text).

## occupation

# Distribution of occupation in the complete data (qplot() is deprecated).
ggplot(data, aes(x = occupation)) +
  geom_bar() +
  theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1))

plot_confusion_matrix(impt_ranger_data, data, miss_data,
                      col = "occupation", method_name = "ranger")
## Warning: Removed 42 rows containing missing values (geom_text).

plot_confusion_matrix(impt_mice_data, data, miss_data,
                      col = "occupation", method_name = "mice")

plot_confusion_matrix(impt_rmidas_data, data, miss_data,
                      col = "occupation", method_name = "rmidas")

## native_country

# Distribution of native_country in the complete data (qplot() is deprecated).
ggplot(data, aes(x = native_country)) +
  geom_bar() +
  theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1))

plot_confusion_matrix(impt_ranger_data, data, miss_data,
                      col = "native_country", method_name = "ranger")
## Warning: Removed 338 rows containing missing values (geom_text).

plot_confusion_matrix(impt_mice_data, data, miss_data,
                      col = "native_country", method_name = "mice")
## Warning: Removed 104 rows containing missing values (geom_text).

plot_confusion_matrix(impt_rmidas_data, data, miss_data,
                      col = "native_country", method_name = "rmidas")
## Warning: Removed 390 rows containing missing values (geom_text).

MICE: hours_per_week

MICE: compare the imputed datasets with the original dataset

# Imputed vs. true hours_per_week on the rows where it was missing, coloured by
# source ("mice-<i>" for each imputed dataset, "True" for the original values).
df_mice_wgt <- create_compare_data(
  data, miss_data, impt_mice_data,
  col = "hours_per_week", method = "mice", sp_impt = "method"
)
ggplot(df_mice_wgt, aes(x = age, y = hours_per_week, colour = source)) +
  geom_point(alpha = 0.4) +
  stat_smooth()

MICE: compare, split by sex

# Same comparison, with source prefixed by sex ("<sex>-mice" / "<sex>-True").
df_mice_wgt <- create_compare_data(
  data, miss_data, impt_mice_data,
  col = "hours_per_week", method = "mice", sp_impt = "sex"
)
ggplot(df_mice_wgt, aes(x = age, y = hours_per_week, colour = source)) +
  geom_point(alpha = 0.4) +
  stat_smooth()

Compare imputed and true values: age

# For the rows where age was made missing, plot true vs. imputed age for mice,
# missRanger and rMIDAS side by side. There is one row of panels per imputed
# dataset (the loop previously plotted dataset 3 every time).
miss_index <- which(is.na(miss_data$age))
true_age <- data$age[miss_index]
sex <- factor(data$sex[miss_index])

# One scatter panel of true vs. imputed age, coloured by sex, with a smoother.
age_panel <- function(true_age, imputed_age, sex, y_label, y_range) {
  panel_data <- data.frame(true_age = true_age, imputed_age = imputed_age, sex = sex)
  ggplot(panel_data, aes(x = true_age, y = imputed_age, colour = sex)) +
    geom_point() +
    stat_smooth() +
    ylim(y_range) +
    ylab(y_label) +
    xlab("data age") +
    theme(legend.position = "top")
}

for (i in seq_len(10)) {
  mice_age <- impt_mice_data[[i]]$age[miss_index]
  ranger_age <- impt_ranger_data[[i]]$age[miss_index]
  midas_age <- impt_rmidas_data[[i]]$age[miss_index]
  # Shared y-range derived from the values being plotted. The old fixed
  # ylim(-5, 22) dropped nearly every age, and a shared range keeps the three
  # panels comparable.
  y_range <- range(c(true_age, mice_age, ranger_age, midas_age), na.rm = TRUE)
  grid.arrange(
    age_panel(true_age, mice_age, sex, "mice age", y_range),
    age_panel(true_age, ranger_age, sex, "ranger age", y_range),
    age_panel(true_age, midas_age, sex, "midas age", y_range),
    ncol = 3
  )
}